from __future__ import division, print_function, absolute_import

import numpy as np
import os
import pickle
from CDTools.tools.cmath import *
import torch as t
from RIP_tools import *

#
# This section runs the numerical experiment designed
# to test the robustness of RIP to missing data.
#

# Band limit of the probe, passed to generate_blr_probe (presumably the
# maximum spatial frequency of the BLR probe — confirm in RIP_tools).
probe_maxk = 128

# Set GPU to False if you don't have a GPU
GPU = True

# Size of the (square) array the probe is simulated on
field_size = probe_maxk * 2
# Focal spot size of our BLR probe
probe_fov_radius = (field_size * 4) // 10

# Error comparisons are done within the central half of the image: a border
# of this fraction of the resolution is trimmed from each side.
error_check_padding = 1 / 4


# Generate an ideal BLR probe with the given parameters
nearfield = generate_blr_probe([field_size, field_size],
                               probe_fov_radius,
                               probe_maxk)


# Sweep over object resolutions. The ratio R = resolution / (2 * probe_maxk)
# (printed below) runs from about 0.375 down to about 0.30; for
# probe_maxk = 128 the resolutions are 96, 90, 84 and 78. Both bounds are
# rounded down to even numbers and the step is even, so every resolution
# is even.
# (An earlier configuration used the single resolution
# int(probe_maxk * 0.8 / 2) * 2, i.e. R = 0.4.)
resolutions = np.arange(int(probe_maxk * 0.75 / 2) * 2,
                        int(probe_maxk * 0.6 / 2) * 2,
                        -6)
print(resolutions / (2 * probe_maxk))
# Final loss and error from every reconstruction, filled in by the sweep
# below as nested lists indexed [resolution][missing band radius][attempt].
losses = []
errors = []

# Sweep over the missing band radius: the half-width, in pixels, of the
# central horizontal and vertical bands removed from the diffraction data.
# A radius of 0 removes nothing.
missing_band_radii = range(0, 30, 2)


def generate_mask(mbr, shape=None):
    """Build a detector mask with a central horizontal and vertical band removed.

    Parameters
    ----------
    mbr : int
        Missing band radius. Rows ``center - mbr`` up to (but excluding)
        ``center + mbr`` are zeroed, and likewise for columns, so each band
        is ``2 * mbr`` pixels wide. ``mbr = 0`` gives an all-ones mask.
        Bands wider than the array are clamped to the array edges.
    shape : sequence of int, optional
        Shape of the mask (at least 2-D). Defaults to ``nearfield.shape[:-1]``,
        i.e. the global probe array with its trailing dimension dropped.

    Returns
    -------
    torch.Tensor
        Float tensor of ones with the two central bands set to zero.
    """
    if shape is None:
        shape = nearfield.shape[:-1]
    # Plain Python ints keep torch's size and slicing arguments unambiguous.
    shape = tuple(int(s) for s in shape)
    mask = t.ones(shape)
    row_c, col_c = shape[0] // 2, shape[1] // 2
    # Clamp the lower edge at 0: a negative start index would wrap around to
    # the far side of the array and leave part of the band unmasked.
    mask[max(row_c - mbr, 0):row_c + mbr, :] = 0
    mask[:, max(col_c - mbr, 0):col_c + mbr] = 0
    return mask
    

# Run the sweep: for every resolution and every missing band radius, simulate
# 1000 random objects, mask their diffraction data, reconstruct, and record
# the final diffraction loss and the real-space image error.
for resolution in resolutions:
    print('Simulation Resolution',resolution)
    # The error-check window depends only on the resolution, so build it once
    # here. `stop` falls back to None so a padding of 0 compares the whole
    # image instead of producing the empty slice 0:-0.
    padding = int(resolution * error_check_padding)
    stop = -padding if padding > 0 else None
    error_window = np.s_[padding:stop, padding:stop]
    grid_losses = []
    grid_errors = []
    for mbr in missing_band_radii:
        print('Simulating missing band radius', mbr)
        loop_losses = []
        loop_errors = []
        mask = generate_mask(mbr)
        # The reconstruction takes a boolean version of the mask. It is the
        # same for every attempt at this radius, so convert it once.
        bool_mask = mask.to(t.bool)

        for attempt in range(1000):
            print('Running attempt',attempt, end='\r')
            # Random complex image with i.i.d. standard-normal real and
            # imaginary parts
            real = np.random.randn(resolution, resolution)
            imag = np.random.randn(resolution, resolution)
            original_image = real + 1j * imag

            original_image = complex_to_torch(original_image).to(dtype=t.float32)

            # Upsample the probe in real space to avoid aliasing in the
            # probe-object interaction
            simulation_nearfield = expand_probe(nearfield, original_image)
            # Simulate the measurement
            exit_wave = interact(original_image, simulation_nearfield)
            diffraction = propagate(exit_wave)
            intensities = measure(diffraction, nearfield.shape)

            # Normalize the diffraction pattern to a total of 10 * 512**2
            # (applied before the missing bands are removed)
            intensities /= t.sum(intensities)
            intensities *= 10 * 512**2

            # Remove the missing bands from the data
            intensities *= mask

            # Attempt a reconstruction from a randomized initialization
            retrieval, loss = reconstruct(intensities, nearfield, resolution, 0.4, 1000, GPU=GPU, schedule=True, mask=bool_mask)
            ret = torch_to_complex(retrieval)
            im = torch_to_complex(original_image)

            # Overall RMS image error (as computed by compare_images) within
            # the central region
            error, _, _ = compare_images(ret, im, error_window)

            # Save the image error and the final diffraction loss
            loop_losses.append(loss[-1])
            loop_errors.append(error)

        grid_errors.append(loop_errors)
        grid_losses.append(loop_losses)

    errors.append(grid_errors)
    losses.append(grid_losses)


errors = np.array(errors)
losses = np.array(losses)

results = {'resolutions': resolutions, 'missing band radii': missing_band_radii, 'losses':losses, 'errors':errors, 'kp': probe_maxk}
output_path = 'data/missing data trial.pickle'
# Make sure the output directory exists, so the simulation results are not
# lost to a missing-directory error at the very end of a long run.
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.isdir(output_dir):
    os.makedirs(output_dir)
with open(output_path, 'wb') as f:
    pickle.dump(results, f)
